import os
from tqdm import tqdm
from easydict import EasyDict as edict
import numpy as np
import random
import torch
from torch import nn
from procedures import Linear_Region_Collector, get_ntk_n, get_ntk_n_zen
from models       import get_cell_based_tiny_net, get_search_spaces, nas_super_nets
from pdb import set_trace as bp
from typing import List, Dict, Tuple, Any, Optional  # noqa 401

# NAS-Bench-201 op name -> column index in a per-edge op mask
# (the column order arch_str2mask_201 writes into).
op2index_201 = {
    op_name: column
    for column, op_name in enumerate((
        'none',
        'skip_connect',
        'nor_conv_1x1',
        'nor_conv_3x3',
        'avg_pool_3x3',
    ))
}

# DARTS op name -> column index.
op2index_darts = {
    op_name: column
    for column, op_name in enumerate((
        'none',
        'skip_connect',
        'sep_conv_3x3',
        'sep_conv_5x5',
        'dil_conv_3x3',
        'dil_conv_5x5',
        'avg_pool_3x3',
        'max_pool_3x3',
    ))
}

# Large finite magnitude; masks hold -INF for every op that is not selected.
INF = 1000

# Reward-type name -> integer id. 'tra', 'exp' and 'gen' are the legacy names
# of the NTK, linear-region and kernel-regression-MSE proxies respectively.
reward_type2index = dict(
    accuracy=-1,
    tra=0,
    exp=1,
    gen=2,
)


def kaiming_normal_fanin_init(m):
    """Initialise a single module in place (meant to be used with ``model.apply``).

    * ``nn.Conv2d`` / ``nn.Linear``: weights from ``kaiming_normal_`` with
      ``mode='fan_in'`` and ``nonlinearity='relu'``; bias (if present) zeroed.
    * ``nn.BatchNorm2d``: weight set to 1 and bias to 0 when the layer is affine.
    * Any other module is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight/bias set to None; skip them
        # instead of crashing on `.data`.
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


def kaiming_normal_fanout_init(m):
    """Initialise a single module in place (meant to be used with ``model.apply``).

    * ``nn.Conv2d`` / ``nn.Linear``: weights from ``kaiming_normal_`` with
      ``mode='fan_out'`` and ``nonlinearity='relu'``; bias (if present) zeroed.
    * ``nn.BatchNorm2d``: weight set to 1 and bias to 0 when the layer is affine.
    * Any other module is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight/bias set to None; skip them
        # instead of crashing on `.data`.
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


def init_model(model, method='kaiming_norm_fanin'):
    """Apply a named weight initialisation to every submodule of ``model``.

    Args:
        model: module to initialise in place.
        method: ``'kaiming_norm_fanin'`` or ``'kaiming_norm_fanout'``. Any other
            value is ignored and the model keeps its current weights.

    Returns:
        The same ``model`` object.
    """
    initializer = None
    if method == 'kaiming_norm_fanout':
        initializer = kaiming_normal_fanout_init
    elif method == 'kaiming_norm_fanin':
        initializer = kaiming_normal_fanin_init
    if initializer is not None:
        model.apply(initializer)
    return model


class Buffer_Reward_Generator_ntk(object):
    """Turn training-free proxy scores of sampled architectures into rewards.

    Every :meth:`step` evaluates one architecture on ``repeat`` randomly
    initialised proxy supernets via ``get_ntk_n_zen`` and appends three values
    to per-type history buffers: ``"ntk"``, ``"region"`` and ``"mse"`` (in
    ``reward_type2index`` these correspond to the legacy types ``'tra'``,
    ``'exp'`` and ``'gen'``). The reward is the signed sum of the latest
    normalized changes of those values, i.e. it measures how much the newest
    architecture improved on the previous one.
    """

    # Current reward-type name -> legacy key used in ``reward_type2index``.
    _LEGACY_REWARD_NAMES = {"ntk": "tra", "region": "exp", "mse": "gen"}

    def __init__(self, xargs, space_name, space_ops, dataset, dataset_val, class_num):
        """Build the proxy supernets and cache the batches used for scoring.

        Args:
            xargs: options namespace. Reads ``te_buffer_size`` (default 10),
                ``batch_size`` (default 64), ``repeat`` (default 3) and
                ``max_nodes``. ``init``, ``batch_size`` and ``repeat`` are
                written back onto this same object (it is not copied).
            space_name: ``'nas-bench-201'``; any other value builds the
                DARTS-style (``'nasnet-super'``) supernet config.
            space_ops: candidate op names forwarded to the model config.
            dataset: source of the cached NTK input batches.
            dataset_val: source of the cached kernel-regression target batches.
            class_num: number of output classes.
        """
        self.reward_type2index = reward_type2index
        self._reward_types = ["ntk", "region", "mse"]
        # Sign applied to each type's normalized change in get_reward(); -1
        # means "lower is better". NOTE(review): an earlier comment said region
        # is higher-the-better, yet its sign is -1 (and _buffer_rank_best_new
        # also treats it as lower-is-better) -- confirm against get_ntk_n_zen.
        self._reward_sign = {"ntk": -1, "mse": -1, "region": -1}
        self._buffers = {key: [] for key in self._reward_types}
        # One flag per evaluated architecture: True if it had bad gradients.
        self._buffers_bad = []
        self._buffers_change = {key: [] for key in self._reward_types}
        self._buffer_length = getattr(xargs, "te_buffer_size", 10)
        self._xargs = xargs
        # NOTE(review): init_model() only acts on 'kaiming_norm_fanin' and
        # 'kaiming_norm_fanout', so 'kaiming_norm' leaves PyTorch's default
        # initialisation in place. Kept as-is; confirm this is intended.
        self._xargs.init = 'kaiming_norm'
        self._xargs.batch_size = getattr(xargs, "batch_size", 64)
        self._xargs.repeat = getattr(xargs, "repeat", 3)
        self._space_name = space_name
        self._space_ops = space_ops
        self._loader = torch.utils.data.DataLoader(dataset, batch_size=self._xargs.batch_size, num_workers=0, pin_memory=True, drop_last=True, shuffle=True)
        self._loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=self._xargs.batch_size, num_workers=0, pin_memory=True, drop_last=True, shuffle=True)
        self._class_num = class_num
        # `_model_config_thin` is built but not used by this class at present.
        if space_name == 'nas-bench-201':
            self._model_config = edict({'name': 'DARTS-V1', 'C': 3, 'N': 1, 'depth': -1, 'use_stem': True,
                                        'max_nodes': xargs.max_nodes, 'num_classes': class_num, 'space': space_ops,
                                        'affine': True, 'track_running_stats': True,
                                        })
            self._model_config_thin = edict({'name': 'DARTS-V1', 'C': 1, 'N': 1, 'depth': 1, 'use_stem': False,
                                             'max_nodes': xargs.max_nodes, 'num_classes': class_num, 'space': space_ops,
                                             'affine': True, 'track_running_stats': True,
                                             })
        else:
            self._model_config = edict({'name': 'DARTS-V1',
                                        'C': 1, 'N': 1, 'depth': 2, 'use_stem': True, 'stem_multiplier': 1,
                                        'max_nodes': xargs.max_nodes, 'num_classes': class_num, 'space': space_ops,
                                        'imagenet': False,
                                        'affine': True, 'track_running_stats': True,
                                        'super_type': 'nasnet-super', 'steps': 4, 'multiplier': 4,
                                        })
            self._model_config_thin = edict({'name': 'DARTS-V1',
                                             'C': 1, 'N': 1, 'depth': 2, 'use_stem': False, 'stem_multiplier': 1,
                                             'max_nodes': xargs.max_nodes, 'num_classes': class_num, 'space': space_ops,
                                             'imagenet': False,
                                             'affine': True, 'track_running_stats': True,
                                             'super_type': 'nasnet-super', 'steps': 4, 'multiplier': 4,
                                             })
        # One randomly initialised proxy supernet per repeat.
        self._networks = []
        for _ in range(self._xargs.repeat):
            network = get_cell_based_tiny_net(self._model_config).cuda().train()
            init_model(network, xargs.init)
            self._networks.append(network)
        # Cache `repeat` batches from each loader so that every architecture is
        # scored on the same data.
        self._ntk_input_data = []
        for i, (inputs, targets) in enumerate(self._loader):
            if i >= self._xargs.repeat:
                break
            self._ntk_input_data.append((inputs, targets))
        self._ntk_target_data = []  # targets for the NTK kernel regression
        for i, (inputs, targets) in enumerate(self._loader_val):
            if i >= self._xargs.repeat:
                break
            self._ntk_target_data.append((inputs, targets))

    def _normalized_change(self, _type, _idx):
        """Return the change of buffer ``_type`` at ``_idx`` w.r.t. ``_idx - 1``.

        The difference is divided by the min-max range of the (up to)
        ``_buffer_length`` entries ending at ``_idx``; 1e-6 avoids division by
        zero. Index 0 has no predecessor, so its change is 0.
        """
        if _idx == 0:
            return 0
        values = self._buffers[_type]
        window = values[max(0, _idx + 1 - self._buffer_length):_idx + 1]
        return (values[_idx] - values[_idx - 1]) / (max(window) - min(window) + 1e-6)

    def _worst(self, _type):
        """Return the worst value currently stored in buffer ``_type``."""
        values = self._buffers[_type]
        return max(values) if self._reward_sign[_type] < 0 else min(values)

    def _is_worse(self, _type, a, b):
        """Return True if value ``a`` is strictly worse than ``b`` for ``_type``."""
        return a > b if self._reward_sign[_type] < 0 else a < b

    def _type_index(self, _type):
        """Map a reward type to its ``reward_type2index`` id (legacy names accepted)."""
        table = self.reward_type2index
        if _type in table:
            return table[_type]
        return table[self._LEGACY_REWARD_NAMES[_type]]

    def _update_bad_cases(self, reward_type, reward):
        """Overwrite ``reward_type`` of every bad architecture with ``reward``.

        Called when a healthy architecture scores worse than everything already
        buffered, so bad architectures keep sitting at the worst observed value.
        Only the ``reward_type`` buffer is touched. The normalized change at each
        rewritten index, and at the index right after it, is recomputed.
        """
        values = self._buffers[reward_type]
        changes = self._buffers_change[reward_type]
        bad_indices = [i for i, isbad in enumerate(self._buffers_bad) if isbad]
        for i in bad_indices:
            values[i] = reward
        for i in bad_indices:
            changes[i] = self._normalized_change(reward_type, i)
            if i + 1 < len(self._buffers_bad):
                changes[i + 1] = self._normalized_change(reward_type, i + 1)

    def arch_str2mask_201(self, arch_str):
        """Convert a NAS-Bench-201 architecture string into a one-hot op mask.

        Op names are read from tokens 1, 3, 4, 6, 7, 8 of ``arch_str.split('|')``
        (each ``'<op>~<input node>'``), e.g.
        ``'|nor_conv_3x3~0|+|none~0|skip_connect~1|+|avg_pool_3x3~0|nor_conv_1x1~1|skip_connect~2|'``.

        Returns:
            ``[mask]``: a 6x5 tensor with 0 at ``(edge, op2index_201[op])`` and
            ``-INF`` everywhere else.
        """
        tokens = arch_str.split('|')
        mask = torch.ones(6, 5) * (-INF)
        for edge, token in enumerate(tokens[i] for i in (1, 3, 4, 6, 7, 8)):
            mask[edge, op2index_201[token.split('~')[0]]] = 0
        return [mask]

    def arch_parameters2mask(self, arch_parameters: List[torch.Tensor]):
        """Convert continuous architecture parameters into single-path masks.

        * ``'nas-bench-201'``: every row (edge) keeps only its argmax op.
        * ``'darts'``: rows are softmax-normalized; for each of the 4 nodes,
          which own 2, 3, 4 and 5 consecutive rows, the 2 rows with the largest
          op weight (all ops considered) are kept, each with its argmax op.

        Kept entries are 0, all others ``-INF``. Any other space name yields ``[]``.
        """
        assert isinstance(arch_parameters, list), f"arch_parameters: {arch_parameters}"
        masks = []
        if self._space_name == 'nas-bench-201':
            for _arch in arch_parameters:
                mask = torch.ones_like(_arch) * (-INF)
                for _idx, edge in enumerate(_arch):
                    mask[_idx][edge.argmax()] = 0
                masks.append(mask)
        elif self._space_name == 'darts':
            for _arch in arch_parameters:
                weights = torch.nn.functional.softmax(_arch.detach().clone(), -1)
                mask = torch.ones_like(weights) * (-INF)
                start, n = 0, 2
                for _ in range(4):
                    # Rank this node's incoming edges by their strongest op.
                    strength = [-max(weights[start + x]) for x in range(n)]
                    edges = sorted(range(n), key=strength.__getitem__)[:2]
                    for edge in edges:
                        mask[edge + start, weights[edge + start].argmax()] = 0
                    start += n
                    n += 1
                masks.append(mask)
        return masks

    def get_ntk_region_mse(self, xargs, arch_parameters, loader):
        """Score ``arch_parameters`` on every cached proxy supernet.

        Args:
            xargs: unused; kept for call compatibility.
            arch_parameters: alpha/mask tensors passed to each network's
                ``set_alphas``.
            loader: unused; the batches cached in ``__init__`` are used instead.

        Returns:
            dict with the mean ``"ntk"``, ``"region"`` and ``"mse"`` over the
            networks, and ``"bad"``: True when any of the three means equals -1
            (presumably get_ntk_n_zen's marker for bad gradients -- confirm).
        """
        for network in self._networks:
            network.set_alphas(arch_parameters)
        ntks, LRs, mses = get_ntk_n_zen(self._ntk_input_data, self._networks, vloader=self._ntk_target_data,
                                        train_mode=True, num_batch=1, num_classes=self._class_num)
        torch.cuda.empty_cache()
        ntk, region, mse = np.mean(ntks), np.mean(LRs), np.mean(mses)
        return {
            "ntk": ntk, "region": region, "mse": mse,
            "bad": ntk == -1 or region == -1 or mse == -1,  # networks of bad gradients
        }

    def get_reward(self):
        """Return the reward of the most recent architecture (larger is better).

        The reward is ``sum(sign[t] * change[t])`` over all reward types, with
        ``change[t]`` the latest normalized change. Returns 0 until at least two
        architectures have been evaluated. Prints the per-type
        ``(type index, signed change)`` pairs.
        """
        if len(self._buffers[self._reward_types[0]]) <= 1:
            return 0  # dummy reward for step 0
        type_reward = [
            (self._type_index(_type), self._reward_sign[_type] * self._buffers_change[_type][-1])
            for _type in self._reward_types
        ]
        _reward = sum(_r for _, _r in type_reward)
        print(type_reward)
        return _reward

    def _buffer_insert(self, results):
        """Append one evaluation to the buffers and record its normalized change.

        Args:
            results: dict with a value per reward type plus a boolean ``"bad"``
                flag, as returned by :meth:`get_ntk_region_mse`. Mutated in
                place: a bad architecture's values are replaced by the current
                worst buffered values.
        """
        if not self._buffers[self._reward_types[0]]:
            # First entry: nothing to compare against yet.
            self._buffers_bad.append(results['bad'])
            for _type in self._reward_types:
                self._buffers_change[_type].append(0)
                self._buffers[_type].append(results[_type])
            return
        for _type in self._reward_types:
            worst = self._worst(_type)
            if results['bad']:
                # Bad gradients: score the architecture as the worst seen so far.
                results[_type] = worst
            elif self._is_worse(_type, results[_type], worst):
                # A new worst value: move earlier bad architectures down to it.
                self._update_bad_cases(_type, results[_type])
        self._buffers_bad.append(results['bad'])
        for _type in self._reward_types:
            self._buffers[_type].append(results[_type])
            self._buffers_change[_type].append(
                self._normalized_change(_type, len(self._buffers[_type]) - 1))

    def step(self, arch, mask=True, verbose=False):
        """Evaluate ``arch``, buffer its proxy scores and return its reward.

        Args:
            arch: a NAS-Bench-201 architecture string, or a list of architecture
                parameter tensors.
            mask: if True, convert ``arch`` into single-path masks first; if
                False use it as-is (e.g. for supernet pruning).
            verbose: print the most recent ``_buffer_length`` values and changes
                of every buffer.

        Returns:
            The value of :meth:`get_reward` (larger is better).
        """
        if not mask:
            arch_parameters = arch
        elif self._space_name == 'nas-bench-201' and isinstance(arch, str):
            arch_parameters = self.arch_str2mask_201(arch)
        else:
            arch_parameters = self.arch_parameters2mask(arch)
        results = self.get_ntk_region_mse(self._xargs, arch_parameters, self._loader)
        self._buffer_insert(results)
        if verbose:
            for _type in self._reward_types:
                print(f"{_type} buffer:", self._buffers[_type][-self._buffer_length:])
                print(f"{_type} change buffer:", self._buffers_change[_type][-self._buffer_length:])
        return self.get_reward()

    def _buffer_rank_best(self):
        """Return the buffer index with the best summed rank over all reward types.

        For each type, an entry's rank is the position of its value in the
        buffer sorted best-first (ties share the smallest position); the index
        with the lowest total rank wins, the first one on ties.
        """
        totals = None
        for _type in self._reward_types:
            values = self._buffers[_type]
            ordered = sorted(values, reverse=self._reward_sign[_type] == 1)  # best first
            ranks = [ordered.index(value) for value in values]
            totals = ranks if totals is None else [a + b for a, b in zip(totals, ranks)]
        return np.argmin(totals)

    def _buffer_rank_best_new(self):
        """Return the index scoring best relative to a randomly drawn reference entry.

        A reference index is drawn with ``random.randint``. Every entry scores
        ``sum_t (ref_t - value_t) / |ref_t|`` over the reward types, so values
        lower than the reference score higher. A reference value of 0 uses the
        absolute difference instead of dividing by zero. Returns the argmax,
        the first one on ties.

        Raises:
            ValueError: if the buffers are empty (from ``random.randint``).
        """
        length = len(self._buffers[self._reward_types[0]])
        ref_idx = random.randint(0, length - 1)
        refs = {_type: self._buffers[_type][ref_idx] for _type in self._reward_types}
        score_history = []
        for i in range(length):
            score = 0
            for _type in self._reward_types:
                ref = refs[_type]
                scale = abs(ref) if ref != 0 else 1
                score += (ref - self._buffers[_type][i]) / scale
            score_history.append(score)
        return np.argmax(score_history)